import torchvision
from torch import nn

class ResNet(nn.Module):
    """ResNet backbone whose final ``fc`` layer is replaced by a sigmoid head.

    The backbone's ``fc`` module is swapped for
    ``Sequential(Linear(in_features, num_classes), Sigmoid())``, so every
    output unit lies in (0, 1) independently.

    Args:
        num_classes: Number of output units in the new head.
        model: Backbone to adapt. It must expose an ``fc`` attribute with an
            ``in_features`` field (e.g. a torchvision ResNet). NOTE: the
            passed-in module is modified in place (its ``fc`` is replaced).
            When omitted, a fresh ImageNet-pretrained ``resnet50`` is built
            for this instance.
    """

    def __init__(self, num_classes=2, model=None):
        super().__init__()
        if model is None:
            # Build the backbone per instance. A default argument of
            # ``resnet50(pretrained=True)`` is evaluated once at import time
            # (downloading weights) and shared by every ResNet(); a second
            # ResNet() then failed because the shared model.fc had already
            # become an nn.Sequential with no ``in_features``.
            model = self._pretrained_resnet50()
        in_features = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Linear(in_features, num_classes),
            # Sigmoid emits per-unit probabilities (suits nn.BCELoss). Drop it
            # if training with nn.CrossEntropyLoss or nn.BCEWithLogitsLoss,
            # which expect raw logits.
            nn.Sigmoid(),
        )
        self.model = model

    @staticmethod
    def _pretrained_resnet50():
        """Return a new ImageNet-pretrained torchvision ``resnet50``."""
        weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
        if weights_enum is not None:
            # Newer torchvision deprecates ``pretrained=True``, which maps to
            # IMAGENET1K_V1; request those weights explicitly.
            return torchvision.models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        # Older torchvision only understands the ``pretrained`` flag.
        return torchvision.models.resnet50(pretrained=True)

    def forward(self, x):
        """Run ``x`` through the adapted backbone and return the sigmoid head output.

        Presumably ``x`` is an image batch of shape (N, 3, H, W) for the
        default resnet50 backbone. TODO: confirm the expected shape for
        custom models.
        """
        return self.model(x)